package org.springblade.ocpp;

import cn.hutool.core.codec.Base64;
import cn.hutool.core.thread.ThreadUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.spring.SpringUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import jakarta.websocket.*;
import jakarta.websocket.server.PathParam;
import jakarta.websocket.server.ServerEndpoint;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.springblade.common.utils.LogFileName;
import org.springblade.common.utils.LoggerUtils;
import org.springblade.core.log.exception.ServiceException;
import org.springblade.modules.charger.pojo.entity.ChargerEntity;
import org.springblade.modules.charger.service.IChargerService;
import org.springblade.ocpp.conf.WebSocketHeaderConfig;
import org.springblade.ocpp.domain.req.BasePayload;
import org.springblade.ocpp.handle.ServerCoreEventHandler;
import org.springblade.ocpp.util.DataAnalysisUtil;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Component
@ServerEndpoint(value = "/server/{sid}", configurator = WebSocketHeaderConfig.class)
public class WebSocketServer {

	private static final Logger log = LoggerUtils.logger(LogFileName.CHARGER);

	/**
	 * Value written to {@code ChargerEntity.status} when a connection is closed.
	 * NOTE(review): presumably means "offline" — confirm against the status enum.
	 */
	private static final int STATUS_DISCONNECTED = 1;

	/**
	 * Value written to {@code ChargerEntity.status} when a connection is opened.
	 * NOTE(review): presumably means "online" — confirm against the status enum.
	 */
	private static final int STATUS_CONNECTED = 2;

	/**
	 * Number of devices currently registered in {@link #webSocketMap}.
	 * Mutated only through the synchronized helpers at the bottom of this class.
	 */
	public static int onlineCount = 0;

	/**
	 * Thread-safe registry mapping each device sid to its endpoint instance.
	 */
	public static ConcurrentHashMap<String, WebSocketServer> webSocketMap = new ConcurrentHashMap<>();

	/**
	 * The connection session of this device; used to send data back to it.
	 */
	private Session session;

	/**
	 * Device identifier taken from the {@code {sid}} path segment.
	 */
	private String sid = "";

	/**
	 * Called when a device connection is established.
	 *
	 * <p>Registers this instance under {@code sid}, replacing any previous registration for the
	 * same device, and records the connected status asynchronously.
	 *
	 * @param session the new WebSocket session
	 * @param sid     device identifier from the request path
	 */
	@OnOpen
	public void onOpen(Session session, @PathParam("sid") String sid) {
		// TODO authentication is skipped during integration testing
		//this.authentication(session);

		this.session = session;
		this.sid = sid;
		// put() is atomic: only a brand-new device bumps the counter; a reconnecting device
		// just replaces its stale entry (the old instance's onClose will then leave it alone).
		WebSocketServer previous = webSocketMap.put(sid, this);
		if (previous == null) {
			addOnlineCount();
		}

		this.editDeviceStatus(sid, STATUS_CONNECTED);
		log.info("[OCPP] 设备【{}】连接成功，当前在线人数为:{}", sid, getOnlineCount());
	}

	/**
	 * Called when the connection is closed.
	 *
	 * <p>Unregisters this instance only if it is still the one mapped to {@code sid}; if the
	 * device has already reconnected, the newer connection stays registered and its status is
	 * not overwritten.
	 */
	@OnClose
	public void onClose() {
		// remove(key, value) succeeds only when the map still points at THIS instance. A plain
		// remove(key) would evict a newer connection of the same device and mark it disconnected.
		if (webSocketMap.remove(sid, this)) {
			subOnlineCount();
			this.editDeviceStatus(sid, STATUS_DISCONNECTED);
		}
		log.info("[OCPP] 设备【{}】退出，当前在线人数为:{}", sid, getOnlineCount());
	}

	/**
	 * Asynchronously sets {@code status} on the charger whose serial number equals {@code sid}.
	 *
	 * @param sid    charger serial number
	 * @param status status value to store
	 */
	private void editDeviceStatus(String sid, int status) {
		ThreadUtil.execute(() -> {
			try {
				IChargerService chargerService = SpringUtil.getBean(IChargerService.class);
				chargerService.update(Wrappers.<ChargerEntity>lambdaUpdate()
					.eq(ChargerEntity::getChargerSn, sid)
					.set(ChargerEntity::getStatus, status)
				);
			} catch (Exception e) {
				// Runs on a pool thread: without this the failure would never reach the charger log.
				log.error("[OCPP] 更新设备【{}】状态失败, status:{}", sid, status, e);
			}
		});
	}

	/**
	 * Called when a text message arrives from the device.
	 *
	 * <p>Validates that the frame is a non-blank JSON array, parses it into a payload, dispatches
	 * it to {@link ServerCoreEventHandler}, and sends the handler's response back to the device.
	 * Any failure is logged; no error frame is sent yet.
	 *
	 * @param message raw text frame sent by the device
	 * @param session the session the frame arrived on
	 */
	@OnMessage
	public void onMessage(String message, Session session) {
		log.info("[OCPP] 接收设备【 {} 】的报文: {}", sid, message);
		try {
			if (StrUtil.isBlank(message)) {
				throw new ServiceException("The message cannot be empty");
			}
			// Only JSON-array frames are accepted.
			if (!JSONUtil.isTypeJSONArray(message)) {
				throw new ServiceException("The data format is incorrect");
			}
			// Checks that the message type is one the protocol defines; null means "nothing to reply".
			BasePayload basePayload = DataAnalysisUtil.reqDataAnalysis(message, sid);
			log.info("[OCPP] onMessage basePayload:{}", JSONUtil.toJsonStr(basePayload));
			if (null == basePayload) {
				return;
			}
			// Let the handler process the device message and build the reply.
			ServerCoreEventHandler serverCoreEventHandler = SpringUtil.getBean(ServerCoreEventHandler.class);
			String res = DataAnalysisUtil.resDataAnalysis(serverCoreEventHandler, basePayload);
			log.info("[OCPP] 响应设备【 {} 】的消息:{}", sid, res);
			this.sendMsg(res);
		} catch (Exception e) {
			log.error("处理消息异常", e);
			// TODO reply to the device with an error frame
		}
	}

	/**
	 * Called when the container reports an error on this connection.
	 *
	 * @param session the affected session
	 * @param error   the reported error
	 */
	@OnError
	public void onError(Session session, Throwable error) {
		log.error("[OCPP] 设备【{}】处理消息错误，原因:{}", this.sid, error.getMessage());
		log.error("[OCPP] 处理错误", error);
	}

	/**
	 * Sends a text message to this device.
	 *
	 * <p>Sends are serialized per session because {@code RemoteEndpoint.Basic} may throw
	 * {@link IllegalStateException} when two threads send concurrently (e.g. an
	 * {@link #onMessage} reply racing a {@link #sendInfo} command).
	 *
	 * @param msg text to send
	 * @throws ServiceException if {@code msg} is empty or the session is not open;
	 *                          I/O failures during the send are logged, not thrown
	 */
	private void sendMsg(String msg) {
		if (StringUtils.isEmpty(msg)) {
			throw new ServiceException("The message cannot be empty");
		}
		Session current = this.session;
		if (current == null || !current.isOpen()) {
			log.warn("[OCPP] {} 不在线, 消息:{}", this.sid, msg);
			throw new ServiceException("Device offline");
		}
		synchronized (current) {
			try {
				current.getBasicRemote().sendText(msg);
			} catch (IOException e) {
				log.error("sendMsg", e);
			}
		}
	}

	/**
	 * Sends a custom message to the device registered under {@code sid}.
	 *
	 * @param message text to send
	 * @param sid     target device identifier
	 * @throws ServiceException if the device is not registered or its session is closed
	 */
	public static void sendInfo(String message, String sid) {
		log.info("[OCPP] 发送消息到设备【 {} 】报文: {}", sid, message);
		// Single lookup: a containsKey/get pair can race with onClose and yield null.
		WebSocketServer target = StringUtils.isEmpty(sid) ? null : webSocketMap.get(sid);
		if (target == null) {
			log.error("[OCPP] 设备【{}】不在线!", sid);
			throw new ServiceException(StrUtil.format("设备{}已离线", sid));
		}
		target.sendMsg(message);
	}

	/**
	 * Broadcasts {@code message} to every registered device. Devices that cannot be reached
	 * are logged and skipped so they do not prevent delivery to the others.
	 *
	 * @param message text to broadcast; ignored when empty
	 * @throws IOException never thrown; kept in the signature for caller compatibility
	 */
	public static void sendMsgAll(String message) throws IOException {
		log.debug("[OCPP] 发送消息到所有设备发送的报文:{}", message);
		if (StringUtils.isEmpty(message)) {
			return;
		}
		for (Map.Entry<String, WebSocketServer> entry : webSocketMap.entrySet()) {
			try {
				entry.getValue().sendMsg(message);
			} catch (ServiceException e) {
				log.warn("[OCPP] 设备【{}】广播失败: {}", entry.getKey(), e.getMessage());
			}
		}
	}

	private static synchronized int getOnlineCount() {
		return onlineCount;
	}

	private static synchronized void addOnlineCount() {
		WebSocketServer.onlineCount++;
	}

	private static synchronized void subOnlineCount() {
		if (WebSocketServer.onlineCount > 0) {
			WebSocketServer.onlineCount--;
		}
	}

	/**
	 * Reads a value that the endpoint configurator stored in the session's user properties.
	 */
	private static String getHeader(Session session, String headerName) {
		return (String) session.getUserProperties().get(headerName);
	}

	/**
	 * Checks the Basic credentials stored in the session properties and closes the session when
	 * they are missing or wrong.
	 *
	 * @param session the session to authenticate
	 */
	public void authentication(Session session) {
		// NOTE(review): "Authorzation" looks like a typo, but it must match the key used by
		// WebSocketHeaderConfig — confirm there before renaming.
		String headerName = "Authorzation";
		String header = getHeader(session, headerName);
		try {
			if (StrUtil.isBlank(header)) {
				log.error("获取header失败，不安全的链接，即将关闭");
				session.close();
				// Stop here: continuing would dereference the missing header.
				return;
			}

			// Strip the "Basic " scheme prefix before decoding.
			String basicPrefix = "Basic ";
			if (header.startsWith(basicPrefix)) {
				header = header.substring(basicPrefix.length());
			}

			String userNamePassword = Base64.decodeStr(header);
			// Never log the header or the decoded credentials: both reveal the password.
			log.info("[OCPP] 鉴权 Authorization header 已解析");
			// TODO real credential check
			if (!"test:test".equals(userNamePassword)) {
				// TODO proper error frame. NOTE(review): the frame below is not valid JSON
				// (StatusNotification is unquoted).
				JSONObject statusNotification = new JSONObject();
				statusNotification.putOpt("errorCode", "Authentication Fail");
				session.getBasicRemote().sendText("[4,123,StatusNotification," + statusNotification + "]");
				session.close();
			}
		} catch (Exception e) {
			log.error("获取头异常", e);
		}
	}


}
